from os.path import basename
from sys import displayhook, platform
from datetime import datetime, timedelta
from pandas import DataFrame

import argparse
import matplotlib.pyplot as plt1
import matplotlib.pyplot as plt2
import pandas as pd
import numpy as np
import re

# NOTE(review): on Solaris ("sunos5") the non-interactive Agg backend is selected
# and seaborn is not imported -- presumably because neither an interactive
# backend nor seaborn is available there; confirm with the target environment.
if platform == "sunos5":
    import matplotlib

    matplotlib.use('Agg')
else:
    # seaborn is used by get_color_palette() to build the line colors.
    import seaborn as sns


def get_color_palette(size):
    """
    Build the color palette used to draw the plotted series.
    :param size: number of columns of the dataframe to plot; one color less
                 than this value is generated.
    :return: seaborn HLS palette with ``size - 1`` colors, or None on
             "sunos5", where seaborn is not imported by this module.
    """
    if platform == 'sunos5':
        return None

    return sns.hls_palette(size - 1, h=.5)


def get_daemons(dataframe):
    """
    Return the distinct values of the ``Daemon`` column.
    :param dataframe: dataset exposing a ``Daemon`` column
    :return: array with each daemon name once, in order of first appearance
    """
    daemon_column = dataframe.Daemon
    return daemon_column.unique()


def get_subdataframe_from_columns(dataframe, columns):
    """
    Build a new dataframe holding only the requested columns.
    :param dataframe: pandas dataframe with CSV data.
    :param columns: names of the columns to copy, in the desired order.
    :return: dataframe with the copied columns; the values are copied as plain
             lists, so the original index is not carried over.
    """
    selection = pd.DataFrame()

    for name in columns:
        values = list(dataframe[name])
        selection[name] = values

    return selection



def split_analysisd_df(dataframe, platform_data):
    """
    Split the dataset generated from wazuh-analysisd.state into three
    dataframes: queue usage (non-cumulative data), event counters
    (cumulative data) and EDPS values. Each one keeps the Timestamp column.
    :param dataframe: pandas dataframe with the analysisd data
    :param platform_data: not used by this function
    :return: tuple (queues dataframe, events dataframe, EDPS dataframe)
    """
    queue_columns = ["Timestamp", "Syscheck queue", "Syscollector queue", "Rootcheck queue", "SCA queue",
                     "Hostinfo queue", "Winevt queue", "Event queue", "Rule matching queue", "Alerts log queue",
                     "Firewall log queue", "Statistical log queue", "Archives log queue"]

    event_columns = ["Timestamp", "Total Events", "Syscheck Events Decoded", "Syscollector Events Decoded",
                     "Rootcheck Events Decoded", "SCA Events Decoded", "HostInfo Events Decoded",
                     "WinEvt Events Decoded", "Other Events Decoded", "Events processed (Rule matching)",
                     "Events received", "Events dropped", "Alerts written", "Firewall alerts written",
                     "FTS alerts written"]

    edps_columns = ["Timestamp", "Syscheck EDPS", "Syscollector EDPS", "Rootcheck EDPS",
                    "SCA EDPS", "HostInfo EDPS", "WinEvt EDPS", "Other EDPS", "Events EDPS (Rule matching)"]

    def extract(column_names):
        # Copy the selected columns, value by value, into a fresh dataframe.
        subset = pd.DataFrame()
        for name in column_names:
            subset[name] = list(dataframe[name])
        return subset

    queues_df = extract(queue_columns)
    events_df = extract(event_columns)
    edps_df = extract(edps_columns)

    return queues_df, events_df, edps_df


def get_subdataframes(dataframe, platform_data):
    """
    Split the complete dataset into one dataframe per metric. Every returned
    dataframe has one column per daemon plus a 'Timestamp' column.

    The metrics (and therefore the returned dataframes) depend on the platform:
      - always: CPU(%), RSS(KB), VMS(KB), FD
      - every platform except 'macos' and 'solaris': Read_Ops, Write_Ops,
        Disk_Read(B), Disk_Written(B), Disk(%), USS(KB)
      - only 'linux': PSS(KB), SWAP(KB)

    :param dataframe: complete dataset, with 'Daemon', 'Timestamp' and one
                      column per metric listed above
    :param platform_data: Platform from which the csv data were obtained
    :return: list of dataframes, in the metric order listed above. If the
             dataset has no rows, the dataframes are empty.
    """
    # Metric columns in the exact order in which the dataframes are returned.
    tags = ['CPU(%)', 'RSS(KB)', 'VMS(KB)', 'FD']

    if platform_data not in ('macos', 'solaris'):
        tags += ['Read_Ops', 'Write_Ops', 'Disk_Read(B)',
                 'Disk_Written(B)', 'Disk(%)', 'USS(KB)']

    if platform_data == 'linux':
        tags += ['PSS(KB)', 'SWAP(KB)']

    # Built before iterating so the result is well defined even when the
    # dataset contains no daemon at all.
    frames = {tag: pd.DataFrame() for tag in tags}

    # sort=False keeps the daemons in order of first appearance.
    for daemon, sub_df in dataframe.groupby('Daemon', sort=False):
        timestamps = list(sub_df['Timestamp'])
        for tag in tags:
            frame = frames[tag]
            frame[daemon] = list(sub_df[tag])
            # The timestamps of the last processed daemon are the ones kept.
            frame['Timestamp'] = timestamps

    return [frames[tag] for tag in tags]


def get_sync_subdataframes(dataframe, platform_data):
    """
    Get 2 new dataframes containing the network
    usage for each agent separately
    :params dataframe: complete dataset
    """
    send_df_merged = pd.DataFrame(columns=['Timestamp'])
    recv_df_merged = pd.DataFrame(columns=['Timestamp'])

    for agent in dataframe.Agent.unique():
        sub_df = dataframe[(dataframe.Agent == agent)]

        send_df_merged = send_df_merged.merge(sub_df[['Bytes_sent', 'Timestamp']],
                                              on='Timestamp', how='outer')
        recv_df_merged = recv_df_merged.merge(sub_df[['Bytes_received', 'Timestamp']],
                                              on='Timestamp', how='outer')

    return send_df_merged, recv_df_merged


def find_maxmin(MAX, MIN, timestamp, df, daemon, interval):
    """
    Collect the maximum and minimum values seen in consecutive intervals of
    ``interval`` minutes. The three output lists are extended in place: the
    first entry is the first sample itself, and a new entry is added every
    time a sample whose minute matches the end of the current interval is
    reached.
    :param MAX: list where the maximum values are appended
    :param MIN: list where the minimum values are appended
    :param timestamp: list where the timestamp closing each interval is appended
    :param df: DataFrame to process; needs a 'Timestamp' column of strings
               whose second ':'-separated field is the minute
    :param daemon: name of the column with the values to process
    :param interval: Sampling interval, in minutes
    """

    def minute_of(stamp):
        # Minute field of an 'HH:MM:SS'-terminated timestamp string.
        return int(stamp.split(':')[1])

    def boundary_after(minute):
        # Minute at which the interval starting at `minute` ends (wraps at 60).
        target = minute + int(interval)
        return target % 60 if target >= 60 else target

    stamps = df["Timestamp"]
    values = df[daemon]

    next_boundary = boundary_after(minute_of(stamps[0]))

    highest = values[0]
    lowest = highest

    MAX.append(highest)
    MIN.append(lowest)
    timestamp.append(stamps[0])

    for value, stamp in zip(values, stamps):
        minute = minute_of(stamp) % 60

        if minute == next_boundary:
            # Interval finished: store its extremes and start a new one.
            MAX.append(highest)
            MIN.append(lowest)
            timestamp.append(stamp)

            highest = value
            lowest = value
            next_boundary = boundary_after(minute)
            continue

        if value > highest:
            highest = value

        if value < lowest:
            lowest = value


def getNextTimestamp(timestamp_pos, timestamp, df, is_upper_limit):
    """
    Look for the position of the timestamp nearest to a reference one, moving
    one second at a time (backwards for an upper limit, forwards otherwise).
    Nothing is searched if ``timestamp_pos`` already holds a position.
    :param timestamp_pos: result of ``np.where`` for the reference timestamp
    :param timestamp: reference timestamp ("%Y-%m-%d %H:%M:%S") where the search starts
    :param df: Dataframe with the Timestamp data
    :param is_upper_limit: True if the reference timestamp is the upper limit
    :return: ``np.where`` result for the first timestamp found; it is empty
             if nothing is found within 420 seconds (7 minutes)
    """
    one_step = timedelta(seconds=-1 if is_upper_limit else 1)

    # Worst case: give up after 420 one-second steps (7 minutes range).
    for _ in range(420):
        if timestamp_pos[0].size != 0:
            break

        current = datetime.strptime(str(timestamp), "%Y-%m-%d %H:%M:%S")
        timestamp = current + one_step
        timestamp_pos = np.where(df['Timestamp'] == str(timestamp))

    return timestamp_pos


def plot_df(data, start_range, end_range, store_name, data_tag, is_MAXMIN, labels=None, title=None, cformat='png'):
    """
    Plot the data from the dataframe and save the figure to
    "<store_name>_<data_tag>[_MAXMIN].<cformat>".

    When both range limits are given (anything other than "-") and can be
    located in the data, three shaded sections are drawn: before the stress
    period, the stress period itself and after it.
    :param data: dataframe with all the data set; needs a 'Timestamp' column
    :param start_range: TimeStamp that reference the beginning of the stress period, or "-"
    :param end_range: TimeStamp that reference the end of the stress period, or "-"
    :param store_name: basename to store the plot.
    :param data_tag: tag used as additional info of the plot and to save the plot.
    :param is_MAXMIN: Specifies whether the _MAXMIN tag is added or not
    :param labels: optional list of labels to add to the plot's legend
    :param title: Plot title
    :param cformat: Custom file format
    """
    plt1.style.use('ggplot')

    total_size = data.shape[0]
    # Roughly 100 ticks on the x axis; one tick per row for short datasets.
    step_size = total_size // 100
    xticks = range(0, total_size, step_size) if step_size else range(0, total_size)

    color_palette = get_color_palette(data.shape[1])
    # Unit suffixes are stripped so they end up neither in the axis label
    # nor in the file name.
    data_tag = re.sub(r'\(\%\)|\(KB\)|\(B\)', '', data_tag)
    output_tag = data_tag + ("_MAXMIN" if is_MAXMIN else "")

    data.plot(x='Timestamp', rot=90, figsize=(26, 9), xticks=xticks,
              color=color_palette, title=title, linewidth=1.5).set_ylabel(output_tag)

    if start_range != "-" and end_range != "-":
        start_range = start_range.replace('_', ' ')
        end_range = end_range.replace('_', ' ')

        # Exact match first; otherwise fall back to the nearest timestamp
        # inside the range (after the start, before the end).
        start = getNextTimestamp(np.where(data['Timestamp'] == start_range), start_range, data, False)
        end = getNextTimestamp(np.where(data['Timestamp'] == end_range), end_range, data, True)

        if start[0].size != 0 and end[0].size != 0:
            # np.where returns index arrays; axvspan needs scalar x positions.
            start_pos = int(start[0][0])
            end_pos = int(end[0][0])
            last_pos = total_size - 1

            idle = plt1.axvspan(0, start_pos, alpha=0.05, color="green")
            stress = plt1.axvspan(start_pos, end_pos, alpha=0.05, color="red")
            unstress = plt1.axvspan(end_pos, last_pos, alpha=0.05, color="yellow")

            leg = plt1.legend([idle, stress, unstress],
                              ["Resources allocation", "Full load", "Resources release"],
                              loc="lower left", bbox_to_anchor=(1.0, 0.7))
            # Keep this legend: the series legend created below would
            # otherwise replace it.
            plt1.gca().add_artist(leg)

    plt1.legend(labels=labels, loc='upper left', bbox_to_anchor=(1.0, 0.7))
    plt1.tight_layout()
    plt1.xlim([0, total_size])
    plt1.savefig("{0}_{1}.{2}".format(store_name, output_tag, cformat), dpi='figure', format=cformat)


def plot(csv_filename, store_name, target, platform_data, title, cformat, start_range, end_range, interval):
    """
    Read a generic CSV and generate some graphics depending on
    the target to plot.
    :param csv_filename: name of the CSV file, or list of names (sync target).
    :param store_name: basename to store the plot.
    :param target: which data vis must perform the function: "sync",
                   "binaries", "analysisd" or "agentd". Any other value
                   plots nothing.
    :param platform_data: Platform from which the csv data were obtained
    :param title: Plot title
    :param cformat: Custom file format
    :param start_range: TimeStamp that reference the beginning of the stress period
    :param end_range: TimeStamp that reference the end of the stress period
    :param interval: Sampling interval
    :raises NotImplementedError: if target is "remoted"
    """

    def split_and_plot(function, csv_name, data_tags, platform_data, title, cformat, target=None):
        """
        Load the CSV data and plot it.
        :param function: callable splitting the data into one dataframe per
                         tag, or "-" to plot the max/min values per interval
        :param csv_name: CSV file name, or list of CSV file names
        :param data_tags: tags of the plots (a single tag when function is "-")
        """
        if isinstance(csv_name, list):
            df = pd.concat(pd.read_csv(f) for f in csv_name)
            labels = ['Agent_{}'.format(agent_id) for agent_id in df.Agent.unique()]
        else:
            df = pd.read_csv(csv_name)
            if target == "binaries":
                # Drop every trailing row sharing the last timestamp
                # (presumably because the last sample may be incomplete).
                last_timestamp = df['Timestamp'].iloc[-1]
                pos = None
                for index in reversed(df.index):
                    if last_timestamp != df['Timestamp'][index]:
                        pos = index + 1
                        break
                if pos is not None:
                    df = df.iloc[:pos]
            labels = None

        if function != "-":
            for data, data_tag in zip(function(df, platform_data), data_tags):
                plot_df(data, start_range, end_range, store_name, data_tag, False, labels, title, cformat)
            return

        # Max/min plot: one column per daemon with the values of the tag.
        daemon_list = list(get_daemons(df))
        values_df = pd.DataFrame()
        for daemon in daemon_list:
            sub_df = df[df.Daemon == daemon]
            values_df[daemon] = list(sub_df[data_tags])
            values_df['Timestamp'] = list(sub_df['Timestamp'])

        data = pd.DataFrame()
        for daemon in daemon_list:
            MAX, MIN, timestamp = [], [], []
            find_maxmin(MAX, MIN, timestamp, values_df, daemon, interval)

            data["Timestamp"] = timestamp
            data[daemon + "_MAX"] = MAX
            data[daemon + "_MIN"] = MIN

        plot_df(data, start_range, end_range, store_name, data_tags, True, labels, title, cformat)

    if target == "sync":
        plot_names = ["Bytes_sent", "Bytes_received"]
        split_and_plot(get_sync_subdataframes, csv_filename, plot_names, platform_data, title, cformat)

    elif target == "binaries":
        plot_names = ["CPU(%)", "RSS(KB)", "VMS(KB)", "FD"]

        if platform_data != 'macos' and platform_data != 'solaris':
            plot_names += ["Read_Ops", "Write_Ops", "Disk_Read(B)",
                           "Disk_Written(B)", "Disk(%)", "USS(KB)"]

        if platform_data == 'windows':
            plot_names[3] = "Handles"

        if platform_data == 'linux':
            plot_names += ["PSS(KB)", "SWAP(KB)"]

        # Plot generic graphs
        split_and_plot(get_subdataframes, csv_filename, plot_names, platform_data, title, cformat, target)

        # NOTE: the max/min RSS plot used to be guarded by
        # `start_date == end_date or start_date != end_date`, which is always
        # true, so the "range shorter than 5 minutes" check never ran. The
        # unreachable branch has been removed; the plot is always generated,
        # exactly as before.
        split_and_plot("-", csv_filename, "RSS(KB)", platform_data, title, cformat)

    elif target == "analysisd":
        plot_names = ["Queues_state", "Number_Events", "EDPS"]
        split_and_plot(split_analysisd_df, csv_filename, plot_names, platform_data, title, cformat)

    elif target == "remoted":
        raise NotImplementedError

    elif target == "agentd":
        dataframe = pd.read_csv(csv_filename)
        target_charts = ["Number of events buffered", "Number of messages", "Number of generated events", "Status"]
        # Numeric value plotted for each agent status.
        status_values = {"connected": 1, "disconnected": 0, "pending": 0.5}

        for chart in target_charts:
            if chart not in dataframe:
                continue

            name = 'AgentD_' + chart.replace(' ', '_')
            subdataframe = get_subdataframe_from_columns(dataframe, ["Timestamp", chart])
            if 'Status' in subdataframe:
                subdataframe["Status"] = subdataframe["Status"].replace(status_values)
            plot_df(subdataframe, start_range, end_range, store_name, name, False, None, title, cformat)


def main():
    """
    Command line entry point: parse the arguments, validate them and generate
    the plots. On invalid arguments a message is printed and the process
    exits with status 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--name', help="Basename of the images", type=str, dest='name')
    parser.add_argument('-p', '--path', help="Path to store the images", type=str, dest='path')
    parser.add_argument('-s', '--source', help="CSV files with the dataset", type=str, dest='source')
    parser.add_argument('-sf', '--sync_files', nargs='+',
                        help="CSV files with the dataset", type=str,
                        dest='source_list')
    parser.add_argument('-t', '--target',
                        help="Target plot of the script: binaries, analysisd, remoted, agentd, sync",
                        type=str, dest='target')
    parser.add_argument('-o', '--platform',
                        help='Platform from which the csv data were obtained', type=str,
                        dest='platform_data', default='linux')
    parser.add_argument('--title',
                        help="Title of the generated chart, add extra info here.",
                        type=str, dest='title')
    parser.add_argument('--format', help="Output format of charts.",
                        type=str, dest='format', default='png')
    parser.add_argument('-i', '--interval', help='Time interval between samples', type=str,
                        dest='interval', default="5")
    # "-" means "no range given".
    parser.add_argument('-sr', '--start', help='Start of data range. Format example: 2021-07-02_20:56:31',
                        type=str, dest='start_range', default="-")
    parser.add_argument('-er', '--end', help='End of data range. Format example: 2021-07-02_20:56:31',
                        type=str, dest='end_range', default="-")
    args = parser.parse_args()

    source_file = args.source if args.source is not None else args.source_list
    start_range, end_range = args.start_range, args.end_range

    # SystemExit is raised directly: the exit() builtin is only injected by
    # the site module and is not guaranteed to exist.
    if (start_range == "-") != (end_range == "-"):
        print("Parameters -sr and -er must be used together if any of them are defined")
        raise SystemExit(1)

    if start_range != "-":
        date_format = '%Y-%m-%d %H:%M:%S'
        try:
            start_time = datetime.strptime(start_range.replace('_', ' '), date_format)
            end_time = datetime.strptime(end_range.replace('_', ' '), date_format)
        except ValueError:
            print("Parameters -sr and -er must follow the format 2021-07-02_20:56:31")
            raise SystemExit(1)

        if end_time < start_time:
            print("The -er parameter must be greater than the -sr parameter")
            raise SystemExit(1)

    if not args.name or not args.path or not source_file:
        print("name, path and source paramaters are needed.")
        raise SystemExit(1)

    plot(source_file, "{0}/{1}".format(args.path, basename(args.name)), args.target,
         args.platform_data, args.title, args.format, start_range, end_range, args.interval)

# Run the command line interface only when the file is executed as a script.
if __name__ == "__main__":
    main()
